{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2022 Google LLC\n",
    "\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ScaNN Demo with GloVe Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import h5py\n",
    "import os\n",
    "import requests\n",
    "import tempfile\n",
    "import time\n",
    "\n",
    "import scann"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    response = requests.get(\"http://ann-benchmarks.com/glove-100-angular.hdf5\")\n",
    "    loc = os.path.join(tmp, \"glove.hdf5\")\n",
    "    with open(loc, 'wb') as f:\n",
    "        f.write(response.content)\n",
    "    \n",
    "    glove_h5py = h5py.File(loc, \"r\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['distances', 'neighbors', 'test', 'train']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(glove_h5py.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1183514, 100)\n",
      "(10000, 100)\n"
     ]
    }
   ],
   "source": [
    "dataset = glove_h5py['train']\n",
    "queries = glove_h5py['test']\n",
    "print(dataset.shape)\n",
    "print(queries.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create ScaNN searcher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "normalized_dataset = dataset / np.linalg.norm(dataset, axis=1)[:, np.newaxis]\n",
    "# configure ScaNN as a tree - asymmetric hash hybrid with reordering\n",
    "# anisotropic quantization as described in the paper; see README\n",
    "\n",
    "# use scann.scann_ops.build() to instead create a TensorFlow-compatible searcher\n",
    "searcher = scann.scann_ops_pybind.builder(normalized_dataset, 10, \"dot_product\").tree(\n",
    "    num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000).score_ah(\n",
    "    2, anisotropic_quantization_threshold=0.2).reorder(100).build()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_recall(neighbors, true_neighbors):\n",
    "    total = 0\n",
    "    for gt_row, row in zip(true_neighbors, neighbors):\n",
    "        total += np.intersect1d(gt_row, row).shape[0]\n",
    "    return total / true_neighbors.size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ScaNN interface features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall: 0.8999\n",
      "Time: 1.3812487125396729\n"
     ]
    }
   ],
   "source": [
    "# this will search the top 100 of the 2000 leaves, and compute\n",
    "# the exact dot products of the top 100 candidates from asymmetric\n",
    "# hashing to get the final top 10 candidates.\n",
    "start = time.time()\n",
    "neighbors, distances = searcher.search_batched(queries)\n",
    "end = time.time()\n",
    "\n",
    "# we are given top 100 neighbors in the ground truth, so select top 10\n",
    "print(\"Recall:\", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))\n",
    "print(\"Time:\", end - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall: 0.92327\n",
      "Time: 1.8380558490753174\n"
     ]
    }
   ],
   "source": [
    "# increasing the leaves to search increases recall at the cost of speed\n",
    "start = time.time()\n",
    "neighbors, distances = searcher.search_batched(queries, leaves_to_search=150)\n",
    "end = time.time()\n",
    "\n",
    "print(\"Recall:\", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))\n",
    "print(\"Time:\", end - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall: 0.93098\n",
      "Time: 2.2772152423858643\n"
     ]
    }
   ],
   "source": [
    "# increasing reordering (the exact scoring of top AH candidates) has a similar effect.\n",
    "start = time.time()\n",
    "neighbors, distances = searcher.search_batched(queries, leaves_to_search=150, pre_reorder_num_neighbors=250)\n",
    "end = time.time()\n",
    "\n",
    "print(\"Recall:\", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))\n",
    "print(\"Time:\", end - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 10) (10000, 10)\n",
      "(10000, 20) (10000, 20)\n"
     ]
    }
   ],
   "source": [
    "# we can also dynamically configure the number of neighbors returned\n",
    "# currently returns 10 as configued in ScannBuilder()\n",
    "neighbors, distances = searcher.search_batched(queries)\n",
    "print(neighbors.shape, distances.shape)\n",
    "\n",
    "# now returns 20\n",
    "neighbors, distances = searcher.search_batched(queries, final_num_neighbors=20)\n",
    "print(neighbors.shape, distances.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 97478 262700 846101 671078 232287]\n",
      "[2.5518737 2.542952  2.539792  2.5383418 2.519638 ]\n",
      "Latency (ms): 0.7724761962890625\n"
     ]
    }
   ],
   "source": [
    "# we have been exclusively calling batch search so far; the single-query call has the same API\n",
    "start = time.time()\n",
    "neighbors, distances = searcher.search(queries[0], final_num_neighbors=5)\n",
    "end = time.time()\n",
    "\n",
    "print(neighbors)\n",
    "print(distances)\n",
    "print(\"Latency (ms):\", 1000*(end - start))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
